import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
from recbole_cdr.model.crossdomain_recommender import CrossDomainRecommender
from recbole_cdr.model.cross_domain_recommender.graphcdr import GraphCDR
from recbole.model.init import xavier_normal_initialization
from recbole.model.loss import EmbLoss
from recbole.utils import InputType
import dgl
import dgl.function as fn
from dgl.nn import GraphConv


def weighted_agg_hetero(graph, user_feat, item_feat):
    """Run one weighted LightGCN message-passing step on a user-item heterograph.

    Every node receives the weighted sum of its neighbours' features, scaled
    by ``1 / sqrt(deg(src) * deg(dst))``, where ``deg`` is the weighted degree
    of a node (the sum of ``edata['w']`` over the edges leaving it).

    Args:
        graph: DGL heterograph with the edge types
            ``('user', 'interact', 'item')`` and
            ``('item', 'rev_interact', 'user')``. Every edge stores its weight
            in ``edata['w']`` with shape ``[E, 1]``.
        user_feat: user features, shape ``[num_users, d]``.
        item_feat: item features, shape ``[num_items, d]``.

    Returns:
        Tuple ``(new_user_feat, new_item_feat)`` with shapes
        ``[num_users, d]`` and ``[num_items, d]``.
    """

    def source_norm(etype, feat):
        # Weighted out-degree of the *source* nodes of ``etype``.
        # ``update_all`` cannot be used for this: it reduces messages onto the
        # destination node type, so the result would land on the other side
        # of the bipartite graph.
        src, _ = graph.edges(etype=etype)
        weight = graph.edges[etype].data['w'].reshape(-1).to(feat.dtype)
        degree = torch.zeros(feat.size(0), dtype=feat.dtype, device=weight.device)
        degree = degree.index_add(0, src.long(), weight)
        # Clamp so isolated nodes do not divide by zero; they neither send
        # nor receive messages, so the large factor never reaches the output.
        return degree.clamp(min=1e-9).rsqrt().unsqueeze(-1)  # [N, 1]

    with graph.local_scope():
        # 1) Symmetric-normalisation factors 1 / sqrt(deg) for both sides.
        user_norm = source_norm('interact', user_feat)
        item_norm = source_norm('rev_interact', item_feat)

        # 2) Pre-scale the features of the sending side.
        graph.nodes['user'].data['h_norm'] = user_feat * user_norm
        graph.nodes['item'].data['h_norm'] = item_feat * item_norm

        # 3) Built-in message / reduce: m = h_norm[src] * w, agg = sum(m).
        #    ``w`` ([E, 1]) broadcasts over the feature dimension.
        message = fn.u_mul_e('h_norm', 'w', 'm')
        reduce = fn.sum('m', 'agg')

        # 4) user -> item propagation, then scale by the receiver's norm.
        graph['interact'].update_all(message, reduce)
        item_agg = graph.nodes['item'].data.get('agg', None)
        item_agg = item_feat if item_agg is None else item_agg * item_norm

        # 5) item -> user propagation, then scale by the receiver's norm.
        graph['rev_interact'].update_all(message, reduce)
        user_agg = graph.nodes['user'].data.get('agg', None)
        user_agg = user_feat if user_agg is None else user_agg * user_norm

        return user_agg, item_agg


class WeightedLightGCNLayerHetero(nn.Module):
    """Single heterogeneous LightGCN aggregation layer (no learnable parameters)."""

    def forward(self, graph, user_emb, item_emb):
        """Propagate user and item embeddings once over ``graph``.

        Delegates to ``weighted_agg_hetero`` and returns its
        ``(user_embeddings, item_embeddings)`` pair.
        """
        new_user_emb, new_item_emb = weighted_agg_hetero(graph, user_emb, item_emb)
        return new_user_emb, new_item_emb

def build_hetero_bipartite_graph(
    source_u, source_i,
    target_u, target_i,
    total_num_users, total_num_items,
    edge_weights=None
):
    """Build a bidirectional user-item heterograph from two interaction sets.

    The source-domain and target-domain interactions are concatenated and
    duplicate ``(user, item)`` pairs are removed. The "undirected" bipartite
    graph is modelled with two relations:
    ``('user', 'interact', 'item')`` and its mirror
    ``('item', 'rev_interact', 'user')``.

    Args:
        source_u, source_i: 1-D id tensors of the source-domain interactions.
        target_u, target_i: 1-D id tensors of the target-domain interactions.
        total_num_users: number of ``'user'`` nodes in the graph.
        total_num_items: number of ``'item'`` nodes in the graph.
        edge_weights: optional per-edge weights. They are attached to both
            relations, so they must contain one row per *deduplicated* pair,
            in the row-sorted order produced by ``torch.unique(..., dim=0)``.
            Defaults to all ones with shape ``[E, 1]``.

    Returns:
        The DGL heterograph with ``edata['w']`` set on both relations.
    """
    # 1) Concatenate the interactions of both domains.
    all_users = torch.cat([source_u, target_u], dim=0)   # [E1 + E2]
    all_items = torch.cat([source_i, target_i], dim=0)   # [E1 + E2]

    # 2) Drop duplicate (user, item) rows.
    unique_pairs = torch.unique(torch.stack([all_users, all_items], dim=1), dim=0)
    src_users = unique_pairs[:, 0]
    dst_items = unique_pairs[:, 1]

    # 3) Build the heterograph. The reverse relation needs its own name:
    #    it is addressed as 'rev_interact' below, and two relations that are
    #    both called 'interact' could not be looked up by name alone.
    g = dgl.heterograph(
        {
            ('user', 'interact', 'item'): (src_users, dst_items),
            ('item', 'rev_interact', 'user'): (dst_items, src_users),
        },
        num_nodes_dict={
            'user': total_num_users,
            'item': total_num_items,
        },
    )

    # 4) Attach edge weights, keeping them on the same device as the ids so
    #    the assignment also works when the interactions live on a GPU.
    if edge_weights is None:
        edge_weights = torch.ones(src_users.shape[0], 1, device=src_users.device)
    else:
        edge_weights = edge_weights.to(src_users.device)
    g.edges['interact'].data['w'] = edge_weights
    g.edges['rev_interact'].data['w'] = edge_weights
    return g
class HTDGLGNN(GraphCDR):
    """Cross-domain recommender that propagates embeddings over one merged graph.

    Source- and target-domain interactions are merged into a single user-item
    heterograph. User and item embeddings are propagated through ``n_layers``
    parameter-free weighted LightGCN layers and the per-layer outputs are
    averaged. Training uses a point-wise BCE loss plus embedding
    regularisation; which domain(s) contribute is chosen by
    ``config['setting']`` (``'transfer'``, ``'merge'`` or
    ``'source_passing'``).

    NOTE(review): attributes such as ``source_u``/``source_i``/``target_u``/
    ``target_i``, ``total_num_users``/``total_num_items``, ``target_num_items``,
    ``drop_rate`` and the ``*_USER_ID``/``*_ITEM_ID`` field names are not
    defined in this class; they are presumably provided by ``GraphCDR`` --
    confirm against the base class.
    """
    input_type = InputType.POINTWISE

    def __init__(self, config, dataset):
        super(HTDGLGNN, self).__init__(config, dataset)
        # load dataset info
        self.config = config
        self.SOURCE_LABEL = dataset.source_domain_dataset.label_field
        self.TARGET_LABEL = dataset.target_domain_dataset.label_field
        self.device = config['device']
        self.latent_dim = config['embedding_size']
        self.n_layers = config['n_layers']
        self.reg_weight = config['reg_weight']

        # Embedding tables shared by both domains. They must exist before
        # ``self.apply(...)`` below, otherwise they are skipped by the
        # Xavier initialisation.
        self.user_embedding = torch.nn.Embedding(num_embeddings=self.total_num_users,
                                                 embedding_dim=self.latent_dim)
        self.item_embedding = torch.nn.Embedding(num_embeddings=self.total_num_items,
                                                 embedding_dim=self.latent_dim)

        # Propagation layers. The merged graph is a heterograph with two
        # relations and ``forward`` propagates a (user, item) pair, which is
        # the interface of ``WeightedLightGCNLayerHetero``.
        self.layers = nn.ModuleList([
            WeightedLightGCNLayerHetero() for _ in range(self.n_layers)
        ])

        # NOTE(review): ``drop_rate`` is not set in this class and the dropout
        # module is not used by ``forward`` -- confirm where it comes from.
        self.dropout = nn.Dropout(p=self.drop_rate)
        self.loss = nn.BCELoss()
        self.sigmoid = nn.Sigmoid()
        self.reg_loss = EmbLoss()

        # Cached full-graph embeddings used to speed up evaluation.
        self.target_restore_user_e = None
        self.target_restore_item_e = None

        self.apply(xavier_normal_initialization)
        self.other_parameter_name = ['target_restore_user_e', 'target_restore_item_e']

        # One graph holding the interactions of both domains.
        self.merge_dgl_graph = build_hetero_bipartite_graph(self.source_u,
                                                            self.source_i,
                                                            self.target_u,
                                                            self.target_i,
                                                            self.total_num_users,
                                                            self.total_num_items).to(self.device)

    def get_ego_embeddings(self):
        """Return the user and item embedding tables stacked as ``[users; items]``."""
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings

    def generate_dgl_graph(self, edge_weights=None):
        """Build a weighted, bidirected homogeneous DGL graph of all interactions.

        Users keep their ids; item ids are shifted by ``total_num_users`` so
        both node kinds share one id space of size
        ``total_num_users + total_num_items``.

        Assumptions stated by the original author:
          - user and item ids in ``source_*`` / ``target_*`` start at 1;
          - id 0 exists as a node for both users and items even though it
            never appears in an interaction;
          - ``total_num_users`` / ``total_num_items`` equal the largest id + 1.

        Args:
            edge_weights: optional weights for the edges of the bidirected
                graph; defaults to all ones with shape ``[num_edges, 1]``.

        Returns:
            The DGL graph with ``edata['w']`` set.
        """
        # Concatenate the interactions of both domains.
        all_users = torch.cat([self.target_u, self.source_u], dim=0)
        all_items = torch.cat([self.target_i, self.source_i], dim=0)

        num_users = self.total_num_users  # expected: max user id + 1 (includes 0)
        num_items = self.total_num_items  # expected: max item id + 1 (includes 0)
        all_items = all_items + num_users  # move items behind the users

        num_nodes_total = num_users + num_items
        g = dgl.graph((all_users, all_items), num_nodes=num_nodes_total)

        # Make the graph undirected.
        g = dgl.to_bidirected(g)
        if edge_weights is None:
            # Create the weights on the graph's device so the assignment also
            # works when the interaction tensors live on a GPU.
            edge_weights = torch.ones((g.num_edges(), 1), dtype=torch.float32,
                                      device=g.device)
        g.edata['w'] = edge_weights
        return g

    def forward(self):
        """Propagate embeddings through all layers and average the layer outputs.

        Returns:
            ``(user_all_embeddings, item_all_embeddings)``: the mean over the
            initial embeddings and the output of every propagation layer.
        """
        # 1) Initial embeddings.
        user_emb = self.user_embedding.weight  # [num_users, dim]
        item_emb = self.item_embedding.weight  # [num_items, dim]

        # 2) Collect the output of every layer (layer 0 = raw embeddings).
        user_all, item_all = [user_emb], [item_emb]
        for layer in self.layers:
            user_emb, item_emb = layer(self.merge_dgl_graph, user_emb, item_emb)
            user_all.append(user_emb)
            item_all.append(item_emb)

        # 3) Average over layers.
        final_user_emb = torch.mean(torch.stack(user_all, dim=0), dim=0)
        final_item_emb = torch.mean(torch.stack(item_all, dim=0), dim=0)
        return final_user_emb, final_item_emb

    def _domain_loss(self, user_all_embeddings, item_all_embeddings, users, items, labels):
        """Return BCE loss + weighted embedding regularisation for one domain's batch."""
        scores = self.sigmoid(
            torch.mul(user_all_embeddings[users], item_all_embeddings[items]).sum(dim=1)
        )
        bce_loss = self.loss(scores, labels)
        # Regularise the raw (pre-propagation) embeddings of the batch.
        reg_loss = self.reg_loss(self.user_embedding(users), self.item_embedding(items))
        return bce_loss + self.reg_weight * reg_loss

    def calculate_loss(self, interaction):
        """Compute the training loss for one batch.

        Args:
            interaction: batch providing the source/target user-id, item-id
                and label fields required by the active setting.

        Returns:
            - ``'transfer'``: the source-domain loss (tensor);
            - ``'merge'``: the tuple ``(source_loss, target_loss)``;
            - ``'source_passing'``: the target-domain loss (tensor).

        Raises:
            ValueError: if ``config['setting']`` is not one of the above.
        """
        self.init_restore_e()
        user_all_embeddings, item_all_embeddings = self.forward()
        setting = self.config['setting']

        def source_loss():
            return self._domain_loss(user_all_embeddings, item_all_embeddings,
                                     interaction[self.SOURCE_USER_ID],
                                     interaction[self.SOURCE_ITEM_ID],
                                     interaction[self.SOURCE_LABEL])

        def target_loss():
            # The regulariser uses the shared embedding tables: this class
            # defines no separate target-domain embeddings.
            return self._domain_loss(user_all_embeddings, item_all_embeddings,
                                     interaction[self.TARGET_USER_ID],
                                     interaction[self.TARGET_ITEM_ID],
                                     interaction[self.TARGET_LABEL])

        if setting == 'transfer':
            return source_loss()
        if setting == 'merge':
            return (source_loss(), target_loss())
        if setting == 'source_passing':
            return target_loss()
        raise ValueError(
            "Unknown config['setting']: {!r}; expected 'transfer', 'merge' "
            "or 'source_passing'.".format(setting)
        )

    def full_sort_predict(self, interaction):
        """Score every target-domain item for the users in ``interaction``.

        Returns:
            Flattened score tensor of length
            ``len(users) * target_num_items``.
        """
        user = interaction[self.TARGET_USER_ID]
        restore_user_e, restore_item_e = self.get_restore_e()
        u_embeddings = restore_user_e[user]
        # Only the first ``target_num_items`` rows are scored.
        i_embeddings = restore_item_e[:self.target_num_items]
        scores = torch.matmul(u_embeddings, i_embeddings.transpose(0, 1))
        return scores.view(-1)

    def init_restore_e(self):
        """Drop the cached embeddings; called at the start of every training step."""
        if self.target_restore_user_e is not None or self.target_restore_item_e is not None:
            self.target_restore_user_e, self.target_restore_item_e = None, None

    def get_restore_e(self):
        """Return the cached propagated embeddings, computing them on first use."""
        if self.target_restore_user_e is None or self.target_restore_item_e is None:
            self.target_restore_user_e, self.target_restore_item_e = self.forward()
        return self.target_restore_user_e, self.target_restore_item_e